# Copyright (c) 2024, Tri Dao, Albert Gu.
"""
this code is modified from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton
"""
"""We want triton==2.1.0 or 2.2.0 for this"""

import paddle
import paddle.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange

from ...utils import custom_bwd, custom_fwd


# Forward kernel of the chunked state-passing scan. One program instance owns a
# BLOCK_SIZE slice of `dim` for a single (batch, head) pair and walks the chunks
# sequentially, carrying the running state:
#     state <- exp(dA_cs[c]) * state + states[c]
# `out[:, c]` receives the state *entering* chunk c (chunk 0 gets the initial
# state, or zeros when HAS_INITSTATES is false), and `final_states` receives the
# state after the last chunk. With HAS_SEQ_IDX, the carried state is zeroed
# (scale -> 0) whenever the sequence index changes between chunk boundaries,
# so packed sequences do not leak state into each other.
@triton.paddle_autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 2048}),
    ],
    key=["dim"],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr,
    out_ptr,
    final_states_ptr,
    dA_cs_ptr,
    initstates_ptr,
    seq_idx_ptr,
    # Matrix dimensions
    dim,
    nchunks,
    seqlen,
    chunk_size,
    # Strides
    stride_states_batch,
    stride_states_chunk,
    stride_states_head,
    stride_states_dim,
    stride_out_batch,
    stride_out_chunk,
    stride_out_head,
    stride_out_dim,
    stride_final_states_batch,
    stride_final_states_head,
    stride_final_states_dim,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_initstates_batch,
    stride_initstates_head,
    stride_initstates_dim,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_INITSTATES: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid is (cdiv(dim, BLOCK_SIZE), batch, nheads): axis 0 tiles `dim`.
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    # Shift all base pointers to this (batch, head) pair.
    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
    if HAS_INITSTATES:
        initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch

    # This program's slice of the `dim` axis; all loads/stores mask to offs_m < dim.
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim

    # Seed the carry with zeros or the provided initial states (accumulated in fp32).
    if not HAS_INITSTATES:
        states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    else:
        initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    # out[:, 0] is the state entering chunk 0, i.e. the initial state.
    tl.store(out_ptrs, states, mask=offs_m < dim)
    out_ptrs += stride_out_chunk
    seq_idx = 0
    for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        # Decay factor for the carried state over chunk c.
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            # Sequence index at the last position of chunk c (clamped to seqlen);
            # a change means a new sequence starts, so drop the carried state.
            seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
            seq_idx = seq_idx_new
        states = scale * states + new_states
        # States entering chunks 1..nchunks-1 go to `out`; the state after the
        # last chunk goes to `final_states` instead.
        if c < nchunks - 1:
            tl.store(out_ptrs, states, mask=offs_m < dim)
        else:
            tl.store(final_states_ptrs, states, mask=offs_m < dim)
        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk


# Backward kernel of the chunked state-passing scan. Walks the chunks in
# reverse, carrying the gradient of the running state:
#     dstates[c] = exp(dA_cs[c]) * dstates[c+1] + dout[c]
# seeded from dfinal_states (or zeros). Along the way it accumulates a
# per-dim-block partial gradient for dA_chunk_cumsum,
#     ddA_cs[c] = sum(out[c] * dstates[c+1]) * exp(dA_cs[c]),
# optionally writes a dtype-converted copy of `out` (CONVERT_STATES), and, when
# HAS_DINITSTATES, propagates the gradient back to the initial states. With
# HAS_SEQ_IDX the scale is zeroed across sequence boundaries, mirroring the
# forward kernel.
@triton.paddle_autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 2048}),
    ],
    key=["dim"],
)
@triton.jit
def _state_passing_bwd_kernel(
    # Pointers to matrices
    dout_ptr,
    out_ptr,
    dA_cs_ptr,
    dfinal_states_ptr,
    seq_idx_ptr,
    dstates_ptr,
    ddA_cs_ptr,
    dinitstates_ptr,
    states_converted_ptr,
    # Matrix dimensions
    dim,
    nchunks,
    seqlen,
    chunk_size,
    # Strides
    stride_dout_batch,
    stride_dout_chunk,
    stride_dout_head,
    stride_dout_dim,
    stride_out_batch,
    stride_out_chunk,
    stride_out_head,
    stride_out_dim,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dfinal_states_batch,
    stride_dfinal_states_head,
    stride_dfinal_states_dim,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_dstates_batch,
    stride_dstates_chunk,
    stride_dstates_head,
    stride_dstates_dim,
    stride_ddA_cs_batch,
    stride_ddA_cs_chunk,
    stride_ddA_cs_head,
    stride_dinitstates_batch,
    stride_dinitstates_head,
    stride_dinitstates_dim,
    # Meta-parameters
    CONVERT_STATES: tl.constexpr,
    HAS_DFINAL_STATES: tl.constexpr,
    HAS_DINITSTATES: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid is (cdiv(dim, BLOCK_SIZE), batch, nheads): axis 0 tiles `dim`.
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    # Shift base pointers to this (batch, head) pair, positioned at the LAST
    # chunk since the scan runs backwards.
    dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
    # `+ pid_m` selects this program's per-block partial-sum slot; assumes the
    # last axis of ddA_cs (size n_blocks) is contiguous.
    ddA_cs_ptr += (
        pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
    )
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
    dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
    if CONVERT_STATES:
        # states_converted mirrors `out` element-for-element, so it reuses the
        # out strides (the launcher asserts matching strides).
        states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
    if HAS_DFINAL_STATES:
        dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
    if HAS_DINITSTATES:
        dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch

    # This program's slice of the `dim` axis; all loads/stores mask to offs_m < dim.
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    dout_ptrs = dout_ptr + offs_m * stride_dout_dim
    if CONVERT_STATES:
        states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim

    # Seed the backward carry with dfinal_states (or zeros), in fp32.
    if HAS_DFINAL_STATES:
        dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(
            tl.float32
        )
    else:
        dstates = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    # dstates for the last chunk is just the seed.
    tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
    if HAS_SEQ_IDX:
        seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
    dstates_ptrs -= stride_dstates_chunk
    # Reverse scan over chunks nchunks-1 .. 1 (pointers were pre-positioned at
    # the last chunk and are decremented at the end of each iteration).
    for c in range(nchunks - 1):
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            # Sequence index at the last position of the preceding chunk; a
            # mismatch means a boundary, so no gradient flows across it.
            seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
            seq_idx = seq_idx_new
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if CONVERT_STATES:
            tl.store(states_converted_ptrs, out, mask=offs_m < dim)
        # Partial gradient w.r.t. dA_chunk_cumsum for this dim-block; the
        # launcher sums the per-block slots afterwards.
        ddA = tl.sum(out * dstates) * scale
        tl.store(ddA_cs_ptr, ddA)
        dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dstates = scale * dstates + dout
        tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
        dout_ptrs -= stride_dout_chunk
        dstates_ptrs -= stride_dstates_chunk
        dA_cs_ptr -= stride_dA_cs_chunk
        ddA_cs_ptr -= stride_ddA_cs_chunk
        out_ptrs -= stride_out_chunk
        if CONVERT_STATES:
            states_converted_ptrs -= stride_out_chunk
    # Handle chunk 0: convert its states if requested, and either zero the
    # remaining ddA slot (no initial states: out[:, 0] is the all-zero seed
    # with no dA dependence) or compute the initial-state gradient.
    if CONVERT_STATES:
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        tl.store(states_converted_ptrs, out, mask=offs_m < dim)
    if not HAS_DINITSTATES:
        tl.store(ddA_cs_ptr, 0.0)
    else:
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            # Gradient reaches the initial states only if the first chunk still
            # belongs to sequence 0.
            scale = tl.where(seq_idx == 0, scale, 0.0)
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        ddA = tl.sum(out * dstates) * scale
        tl.store(ddA_cs_ptr, ddA)
        dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dstates = scale * dstates + dout
        tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)


def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None):
    """Launch the forward state-passing kernel.

    Arguments:
        states: (batch, nchunks, nheads, dim) per-chunk states to be scanned
        dA_chunk_cumsum: (batch, nheads, nchunks) log-decay per chunk
        initial_states: optional (batch, nheads, dim)
        seq_idx: optional (batch, seqlen) int tensor for packed sequences;
            requires chunk_size
        chunk_size: chunk length, only used together with seq_idx
        out_dtype: dtype of the returned `out` (defaults to states.dtype)
    Returns:
        out: (batch, nchunks, nheads, dim) state entering each chunk
        final_states: (batch, nheads, dim) fp32 state after the last chunk
    """
    batch, nchunks, nheads, dim = states.shape
    assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks)
    if initial_states is not None:
        assert tuple(initial_states.shape) == (batch, nheads, dim)
    has_seq_idx = seq_idx is not None
    if has_seq_idx:
        assert chunk_size is not None
        seqlen = seq_idx.shape[-1]
        assert tuple(seq_idx.shape) == (batch, seqlen)
    else:
        # The kernel ignores these when HAS_SEQ_IDX is false; pass zeros.
        seqlen = 0
        chunk_size = 0

    out = paddle.empty((batch, nchunks, nheads, dim), dtype=out_dtype if out_dtype is not None else states.dtype)
    final_states = paddle.empty((batch, nheads, dim), dtype=paddle.float32)

    # Strides for optional tensors default to zeros when the tensor is absent.
    if initial_states is not None:
        init_strides = (initial_states.strides[0], initial_states.strides[1], initial_states.strides[2])
    else:
        init_strides = (0, 0, 0)
    seq_idx_strides = (seq_idx.strides[0], seq_idx.strides[1]) if has_seq_idx else (0, 0)

    def grid(META):
        # One program per BLOCK_SIZE slice of dim, per (batch, head).
        return (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)

    _state_passing_fwd_kernel[grid](
        states,
        out,
        final_states,
        dA_chunk_cumsum,
        initial_states,
        seq_idx,
        dim,
        nchunks,
        seqlen,
        chunk_size,
        *states.strides,
        *out.strides,
        *final_states.strides,
        dA_chunk_cumsum.strides[0],
        # dA layout is (batch, nheads, nchunks): chunk stride is axis 2, head stride axis 1.
        dA_chunk_cumsum.strides[2],
        dA_chunk_cumsum.strides[1],
        *init_strides,
        *seq_idx_strides,
        HAS_INITSTATES=initial_states is not None,
        HAS_SEQ_IDX=has_seq_idx,
    )
    return out, final_states


def _state_passing_bwd(
    states,
    dA_chunk_cumsum,
    dout,
    dfinal_states=None,
    seq_idx=None,
    has_initial_states=None,
    dstates_dtype=None,
    states_dtype=None,
    chunk_size=None,
):
    """
    Backward of the chunked state-passing scan.

    states contains the initial_states at index 0. The final states are not included in states.

    Arguments:
        states: (batch, nchunks, nheads, dim) -- the forward pass's `out`
        dA_chunk_cumsum: (batch, nheads, nchunks)
        dout: (batch, nchunks, nheads, dim) gradient w.r.t. the forward `out`
        dfinal_states: optional (batch, nheads, dim) gradient w.r.t. final_states
        seq_idx: optional (batch, seqlen) int tensor; requires chunk_size
        has_initial_states: when truthy, also compute the initial-state gradient
        dstates_dtype: dtype of the returned dstates (defaults to dout.dtype)
        states_dtype: when given and different from states.dtype, additionally
            return a converted copy of `states`
        chunk_size: chunk length, only used together with seq_idx
    Returns:
        (dstates, ddA_chunk_cumsum, dinitstates), plus states_converted when
        states_dtype is not None.
    """
    batch, nchunks, nheads, dim = states.shape
    assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks)
    assert tuple(dout.shape) == (batch, nchunks, nheads, dim)
    if seq_idx is not None:
        assert chunk_size is not None
        seqlen = seq_idx.shape[-1]
        assert tuple(seq_idx.shape) == (batch, seqlen)
    dstates = paddle.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
    if states_dtype is not None and states_dtype != states.dtype:
        # NOTE(review): allocated with dstates_dtype/dout.dtype rather than
        # states_dtype, even though the surrounding check is about states_dtype.
        # This matches the upstream mamba source, but looks inconsistent --
        # confirm intended.
        states_converted = paddle.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
        # Kernel reuses `out` strides for states_converted, so they must match.
        assert states_converted.strides == states.strides
    else:
        states_converted = None
    if has_initial_states:
        dinitstates = paddle.empty_like(dstates[:, 0])
    else:
        dinitstates = None
    if dfinal_states is not None:
        assert tuple(dfinal_states.shape) == (batch, nheads, dim)
    # ddA is accumulated per dim-block; size the last axis for the smallest
    # autotuned BLOCK_SIZE (the worst case) and sum the used slots afterwards.
    BLOCK_SIZE_min = 64
    n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
    ddA_chunk_cumsum = paddle.empty([batch, nheads, nchunks, n_blocks], dtype=paddle.float32)
    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
    _state_passing_bwd_kernel[grid](
        dout,
        states,
        dA_chunk_cumsum,
        dfinal_states,
        seq_idx,
        dstates,
        ddA_chunk_cumsum,
        dinitstates,
        states_converted,
        dim,
        nchunks,
        seqlen if seq_idx is not None else 0,
        chunk_size if seq_idx is not None else 0,
        dout.strides[0],
        dout.strides[1],
        dout.strides[2],
        dout.strides[3],
        states.strides[0],
        states.strides[1],
        states.strides[2],
        states.strides[3],
        # dA layout is (batch, nheads, nchunks): chunk stride is axis 2, head stride axis 1.
        dA_chunk_cumsum.strides[0],
        dA_chunk_cumsum.strides[2],
        dA_chunk_cumsum.strides[1],
        *(
            (dfinal_states.strides[0], dfinal_states.strides[1], dfinal_states.strides[2])
            if dfinal_states is not None
            else (0, 0, 0)
        ),
        *((seq_idx.strides[0], seq_idx.strides[1]) if seq_idx is not None else (0, 0)),
        dstates.strides[0],
        dstates.strides[1],
        dstates.strides[2],
        dstates.strides[3],
        ddA_chunk_cumsum.strides[0],
        ddA_chunk_cumsum.strides[2],
        ddA_chunk_cumsum.strides[1],
        *(
            (dinitstates.strides[0], dinitstates.strides[1], dinitstates.strides[2])
            if dinitstates is not None
            else (0, 0, 0)
        ),
        CONVERT_STATES=states_converted is not None,
        HAS_DFINAL_STATES=dfinal_states is not None,
        HAS_DINITSTATES=dinitstates is not None,
        HAS_SEQ_IDX=seq_idx is not None,
    )
    # The autotuner picked the actual BLOCK_SIZE; only that many block-slots
    # were written, so reduce over just those before casting back.
    BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
    n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
    ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(axis=-1).cast(dtype=dA_chunk_cumsum.dtype)
    if states_dtype is not None and states_dtype == states.dtype:
        # No conversion needed: hand back the input tensor itself.
        states_converted = states
    return (
        (dstates, ddA_chunk_cumsum, dinitstates)
        if states_dtype is None
        else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
    )


class StatePassingFn(paddle.autograd.PyLayer):
    """Autograd layer tying _state_passing_fwd / _state_passing_bwd together."""

    @staticmethod
    @custom_fwd
    def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
        """Run the forward scan and stash (out, dA_chunk_cumsum) for backward.

        Returns (out, final_states); see `state_passing` for shapes.
        """
        batch, nchunks, nheads, dim = states.shape
        assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks)
        # The kernel indexes the dim axis elementwise; require a unit last stride.
        if states.strides[-1] != 1:
            states = states.contiguous()
        out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
        ctx.save_for_backward(out, dA_chunk_cumsum)
        # Remember whether to produce an initial-states gradient in backward.
        ctx.has_initial_states = initial_states is not None
        return out, final_states

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, dfinal_states):
        """Gradients w.r.t. (states, dA_chunk_cumsum, initial_states)."""
        out, dA_chunk_cumsum = ctx.saved_tensor()
        batch, nchunks, nheads, dim = out.shape
        assert tuple(dout.shape) == (batch, nchunks, nheads, dim)
        assert tuple(dA_chunk_cumsum.shape) == (batch, nheads, nchunks)
        assert tuple(dfinal_states.shape) == (batch, nheads, dim)
        # Same contiguity requirement as forward, now on the incoming gradient.
        if dout.strides[-1] != 1:
            dout = dout.contiguous()
        dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
            out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states, has_initial_states=ctx.has_initial_states
        )
        return dstates, ddA_chunk_cumsum, dinitstates


def state_passing(states, dA_chunk_cumsum, initial_states=None):
    """
    Differentiable chunked state-passing scan (Triton-accelerated).

    Argument:
        states: (batch, nchunks, nheads, dim)
        dA_chunk_cumsum: (batch, nheads, nchunks)
        initial_states: (batch, nheads, dim)
    Return:
        out: (batch, nchunks, nheads, dim)
        final_states: (batch, nheads, dim)
    """
    return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)


def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
    """
    Pure-Paddle reference for `state_passing` (used for testing the kernels).

    Argument:
        states: (batch, nchunks, nheads, dim)
        dA_chunk_cumsum: (batch, nheads, nchunks)
        initial_states: (batch, nheads, dim)
    Return:
        out: (batch, nchunks, nheads, dim)
        final_states: (batch, nheads, dim)
    """
    init = paddle.zeros_like(states[:, 0]) if initial_states is None else initial_states
    # Prepend the initial state as a pseudo-chunk: (batch, nchunks+1, nheads, dim).
    all_states = paddle.concat([init.unsqueeze(1), states], axis=1)
    # Prepend a zero to the log-decays, then take the running sum over chunks.
    cumsum = paddle.cumsum(F.pad(dA_chunk_cumsum, (1, 0), data_format="NCL"), axis=-1)
    n = cumsum.shape[-1]
    # Pairwise decay from chunk c to chunk z: exp(cumsum[z] - cumsum[c]),
    # shape (batch, nheads, n, n).
    decay = paddle.exp(cumsum[:, :, :, None] - cumsum[:, :, None, :])
    # Only earlier (or equal) chunks may contribute: lower-triangular mask.
    keep = paddle.tril(paddle.ones([n, n], dtype=bool), diagonal=0)
    decay = decay.masked_fill(~keep, 0)
    # Weighted sum of all earlier states for every target chunk.
    scanned = paddle.einsum("bhzc,bchd->bzhd", decay.cast(dtype=all_states.dtype), all_states)
    # Row n-1 is the state after the last chunk; the rest are per-chunk entries.
    return scanned[:, :-1], scanned[:, -1]
